{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "private_outputs": true,
      "provenance": [],
      "gpuType": "T4",
      "authorship_tag": "ABX9TyOc3hLRZfvb380DcdrQGPtU",
      "include_colab_link": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/yerfor/GeneFacePlusPlus/blob/main/inference/genefacepp_demo.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "Check GPU"
      ],
      "metadata": {
        "id": "QS04K9oO21AW"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "1ESQRDb-yVUG"
      },
      "outputs": [],
      "source": [
        "!nvidia-smi --query-gpu=name,memory.total,memory.free --format=csv,noheader"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "Installation"
      ],
      "metadata": {
        "id": "y-ctmIvu3Ei8"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Install PyTorch3D if it is not already importable, about 15s\n",
        "import os\n",
        "import sys\n",
        "import torch\n",
        "\n",
        "try:\n",
        "    import pytorch3d\n",
        "    need_pytorch3d = False\n",
        "except ModuleNotFoundError:\n",
        "    need_pytorch3d = True\n",
        "\n",
        "if need_pytorch3d:\n",
        "    # Prebuilt wheels are only looked up for Linux + PyTorch 2.1.x CUDA builds.\n",
        "    # torch.version.cuda is None on CPU-only builds, so check it before calling\n",
        "    # .replace() on it and fall back to the source build in that case.\n",
        "    cuda_version = torch.version.cuda\n",
        "    can_use_wheel = (\n",
        "        torch.__version__.startswith(\"2.1.\")\n",
        "        and sys.platform.startswith(\"linux\")\n",
        "        and cuda_version is not None\n",
        "    )\n",
        "    if can_use_wheel:\n",
        "        # We try to install PyTorch3D via a released wheel.\n",
        "        # The wheel index directory is named py3<minor>_cu<cuda>_pyt<torch>, e.g. py310_cu121_pyt210.\n",
        "        pyt_version_str = torch.__version__.split(\"+\")[0].replace(\".\", \"\")\n",
        "        version_str = \"\".join([\n",
        "            f\"py3{sys.version_info.minor}_cu\",\n",
        "            cuda_version.replace(\".\", \"\"),\n",
        "            f\"_pyt{pyt_version_str}\"\n",
        "        ])\n",
        "        # %pip (rather than !pip) runs pip for the environment of the running kernel\n",
        "        %pip install fvcore iopath\n",
        "        %pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html\n",
        "    else:\n",
        "        # We try to install PyTorch3D from source.\n",
        "        %pip install 'git+https://github.com/facebookresearch/pytorch3d.git@stable'"
      ],
      "metadata": {
        "id": "gXu76wdDgaxo"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# install dependencies, about 5~10 min\n",
        "# %pip (rather than !pip) runs pip for the environment of the running kernel\n",
        "%pip install tensorboard==2.13.0 tensorboardX==2.6.1\n",
        "%pip install pyspy==0.1.1\n",
        "%pip install protobuf==3.20.3\n",
        "%pip install scipy==1.9.1\n",
        "%pip install kornia==0.5.0\n",
        "%pip install trimesh==3.22.0\n",
        "%pip install einops==0.6.1 torchshow==0.5.1\n",
        "%pip install imageio==2.31.1 imageio-ffmpeg==0.4.8\n",
        "%pip install scikit-learn==1.3.0 scikit-image==0.21.0\n",
        "%pip install av==10.0.0 lpips==0.1.4\n",
        "%pip install timm==0.9.2 librosa==0.9.2\n",
        "%pip install openmim==0.3.9\n",
        "!mim install mmcv==2.1.0 # use mim to speed up installation for mmcv\n",
        "%pip install transformers==4.33.2\n",
        "%pip install pretrainedmodels==0.7.4\n",
        "%pip install ninja==1.11.1\n",
        "%pip install faiss-cpu==1.7.4\n",
        "%pip install praat-parselmouth==0.4.3 moviepy==1.0.3\n",
        "%pip install mediapipe==0.10.7\n",
        "# NOTE(review): 'attr' and 'attrs' are different PyPI projects - confirm which one is intended here\n",
        "%pip install --upgrade attr\n",
        "%pip install beartype==0.16.4 gateloop_transformer==0.4.0\n",
        "%pip install torchode==0.2.0 torchdiffeq==0.2.3\n",
        "%pip install hydra-core==1.3.2 pandas==2.1.3\n",
        "%pip install pytorch_lightning==2.1.2\n",
        "%pip install httpx==0.23.3\n",
        "%pip install gradio==4.16.0\n",
        "# the packages below are not version-pinned, so pip picks whatever release is current\n",
        "%pip install gdown\n",
        "%pip install pyloudnorm webrtcvad pyworld==0.2.1rc0 pypinyin==0.42.0\n",
        "%pip install PyMCubes\n"
      ],
      "metadata": {
        "id": "DuUynxmotG_-"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# If you hit runtime errors after the installs above, uncomment the two lines below:\n",
        "# they kill the current kernel process so that the runtime restarts.\n",
        "# import os\n",
        "# os.kill(os.getpid(), 9)"
      ],
      "metadata": {
        "id": "0GLEV0HVu8rj"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "Clone code and download checkpoints"
      ],
      "metadata": {
        "id": "5UfKHKrH6kcq"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# clone GeneFacePlusPlus repo from github and make it the working directory\n",
        "import os\n",
        "\n",
        "# Re-running this cell must not clone a nested copy or cd one level deeper,\n",
        "# so do nothing when the working directory is already the repo.\n",
        "if os.path.basename(os.getcwd()) != 'GeneFacePlusPlus':\n",
        "    # skip the clone when a previous run already created the directory\n",
        "    if not os.path.isdir('GeneFacePlusPlus'):\n",
        "        !git clone https://github.com/yerfor/GeneFacePlusPlus\n",
        "    %cd GeneFacePlusPlus"
      ],
      "metadata": {
        "id": "-gfRsd9DwIgl"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Replace '-std=c++14' with '-std=c++17' and drop '-use_fast_math' in the\n",
        "# modules/radnerfs Python files, for compatibility with CUDA 12.\n",
        "# special thanks to @yyheart in https://github.com/yerfor/GeneFacePlusPlus/issues/12\n",
        "import glob\n",
        "\n",
        "# '**' with recursive=True matches .py files at any depth below modules/radnerfs\n",
        "all_files = sorted(glob.glob(\"modules/radnerfs/**/*.py\", recursive=True))\n",
        "for path in all_files:\n",
        "    with open(path) as f:\n",
        "        original_text = f.read()\n",
        "    patched_text = original_text.replace('-std=c++14', '-std=c++17')\n",
        "    patched_text = patched_text.replace(\"'-use_fast_math'\", '')\n",
        "    # only rewrite (and report) the files whose content actually changed\n",
        "    if patched_text != original_text:\n",
        "        with open(path, 'w') as f:\n",
        "            f.write(patched_text)\n",
        "        print(f\"patched {path}\")\n",
        "\n",
        "# build extensions, it takes a long time (about 10 min)\n",
        "!bash docs/prepare_env/install_ext.sh"
      ],
      "metadata": {
        "id": "yEK4smIUN0hI"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# download pretrained ckpts & third-party ckpts from google drive, about 1 min\n",
        "# (every %cd below is relative, so run this cell from the repository root)\n",
        "\n",
        "# file downloaded into data/binary/videos/May\n",
        "%mkdir -p data/binary/videos/May\n",
        "%cd data/binary/videos/May\n",
        "!gdown https://drive.google.com/uc?id=16fNJz5MbOMqHYHxcK_nPP4EPBXWjugR0\n",
        "%cd ../../../..\n",
        "\n",
        "# files downloaded into deep_3drecon/BFM\n",
        "%cd deep_3drecon/BFM\n",
        "!gdown https://drive.google.com/uc?id=1SPM3IHsyNAaVMwqZZGV6QVaV7I2Hly0v\n",
        "!gdown https://drive.google.com/uc?id=1MSldX9UChKEb3AXLVTPzZQcsbGD4VmGF\n",
        "!gdown https://drive.google.com/uc?id=180ciTvm16peWrcpl4DOekT9eUQ-lJfMU\n",
        "!gdown https://drive.google.com/uc?id=1KX9MyGueFB3M-X0Ss152x_johyTXHTfU\n",
        "!gdown https://drive.google.com/uc?id=19-NyZn_I0_mkF-F5GPyFMwQJ_-WecZIL\n",
        "!gdown https://drive.google.com/uc?id=11ouQ7Wr2I-JKStp2Fd1afedmWeuifhof\n",
        "!gdown https://drive.google.com/uc?id=18ICIvQoKX-7feYWP61RbpppzDuYTptCq\n",
        "!gdown https://drive.google.com/uc?id=1VktuY46m0v_n_d4nvOupauJkK4LF6mHE\n",
        "%cd ../..\n",
        "\n",
        "# files downloaded into checkpoints and checkpoints/audio2motion_vae\n",
        "%mkdir -p checkpoints\n",
        "%cd checkpoints\n",
        "!gdown https://drive.google.com/uc?id=1O5C1vK4yqguOhgRQ7kmYqa3-E8q5H_65\n",
        "# -o overwrites existing files without prompting, so re-running the cell does not stall\n",
        "!unzip -o motion2video_nerf.zip\n",
        "# -p keeps mkdir from failing when the directory already exists on a re-run\n",
        "%mkdir -p audio2motion_vae\n",
        "%cd audio2motion_vae\n",
        "!gdown https://drive.google.com/uc?id=1Qg5V-1-IyEgAOxb2PbBjHpYkizuy6njf\n",
        "!gdown https://drive.google.com/uc?id=1bKY5rn3vcAkv-2m1mui0qr4Fs38jEmy-\n",
        "%cd ..\n",
        "\n",
        "%cd ..\n"
      ],
      "metadata": {
        "id": "Yju8dQY7x5OS"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "Inference sample"
      ],
      "metadata": {
        "id": "LHzLro206pnA"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# sample inference, about 3 min\n",
        "!python inference/genefacepp_infer.py \\\n",
        "--a2m_ckpt=checkpoints/audio2motion_vae \\\n",
        "--head_ckpt= \\\n",
        "--torso_ckpt=checkpoints/motion2video_nerf/may_torso \\\n",
        "--drv_aud=data/raw/val_wavs/MacronSpeech.wav \\\n",
        "--out_name=may_demo.mp4 \\\n",
        "--low_memory_usage"
      ],
      "metadata": {
        "id": "2aCDwxNivQoS"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "Display output video"
      ],
      "metadata": {
        "id": "XL0c54l19mBG"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# borrowed from MakeItTalk: embed the rendered video in the notebook output\n",
        "from IPython.display import HTML\n",
        "from base64 import b64encode\n",
        "import os, sys\n",
        "import glob\n",
        "\n",
        "mp4_name = './may_demo.mp4'\n",
        "\n",
        "# read inside a context manager so the file handle is closed right after reading\n",
        "with open(mp4_name, 'rb') as mp4_file:\n",
        "    mp4 = mp4_file.read()\n",
        "# the video bytes are inlined into the page as a base64 data URL\n",
        "data_url = \"data:video/mp4;base64,\" + b64encode(mp4).decode()\n",
        "\n",
        "print('Display animation: {}'.format(mp4_name), file=sys.stderr)\n",
        "display(HTML(\"\"\"\n",
        "  <video width=256 controls>\n",
        "        <source src=\"%s\" type=\"video/mp4\">\n",
        "  </video>\n",
        "  \"\"\" % data_url))"
      ],
      "metadata": {
        "id": "6olmWwZP9Icj"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}